// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package eventbustest

import (
	"errors"
	"fmt"
	"reflect"
	"testing"
	"time"

	"tailscale.com/util/eventbus"
)

// NewBus returns a fresh [eventbus.Bus] for use in the test t. The bus's
// Close method is registered with t.Cleanup, so the bus is shut down
// automatically once the test finishes.
func NewBus(t *testing.T) *eventbus.Bus {
	b := eventbus.New()
	t.Cleanup(func() { b.Close() })
	return b
}

// NewWatcher constructs a [Watcher] that can be used to check the stream of
// events generated by code under test. After construction the caller may use
// [Expect] and [ExpectExactly] to verify that the desired events were captured.
//
// The watcher's TimeOut defaults to 5 seconds. If t has a deadline, TimeOut is
// instead set to the time remaining until that deadline. The watcher stops
// monitoring the bus when t's cleanup functions run.
func NewWatcher(t *testing.T, bus *eventbus.Bus) *Watcher {
	tw := &Watcher{
		mon:     bus.Debugger().WatchBus(),
		TimeOut: 5 * time.Second,
		chDone:  make(chan bool, 1),
		events:  make(chan any, 100),
	}
	if deadline, ok := t.Deadline(); ok {
		tw.TimeOut = time.Until(deadline)
	}
	t.Cleanup(tw.done)
	go tw.watch()
	return tw
}

// Watcher monitors and holds events for test expectations.
//
// Watchers are created by [NewWatcher]. The events they capture are consumed,
// in arrival order, by [Expect] and [ExpectExactly].
type Watcher struct {
	// mon is the debugger subscription (from bus.Debugger().WatchBus())
	// that the watcher reads routed events from.
	mon *eventbus.Subscriber[eventbus.RoutedEvent]

	// events buffers captured event payloads until an Expect* call reads them.
	events chan any

	// chDone is closed by done to stop the watcher.
	chDone chan bool

	// TimeOut bounds how long the Expect* functions wait for the next event
	// from the Watcher before failing. The limit applies to each wait, not to
	// the whole call. [NewWatcher] sets it to 5 seconds, or to the time
	// remaining until the test's deadline when the test has one. To verify the
	// absence of an event, lower TimeOut after creating the Watcher.
	TimeOut time.Duration
}

// Type returns a filter expressing the expectation of seeing an event of type
// T, without caring about the content of the event. It is shorthand for
// writing func(T) {} by hand, and makes it possible to use helpers like:
//
//	eventbustest.Expect(tw, eventbustest.Type[EventFoo]())
func Type[T any]() func(T) {
	return func(_ T) {}
}

// Expect verifies that the given events are a subsequence of the events
// observed by tw. That is, tw must contain at least one event matching the type
// of each argument, in the given order; other events are allowed to occur in
// between without error. The given events are represented by a function
// that must have one of the following forms:
//
//	// Tests for the event type only
//	func(e ExpectedType)
//
//	// Tests for event type and whatever is defined in the body.
//	// If return is false, the test will look for other events of that type
//	// If return is true, the test will look for the next given event
//	// if a list is given
//	func(e ExpectedType) bool
//
//	// Tests for event type and whatever is defined in the body.
//	// The boolean return works as above.
//	// If error != nil, the test helper will return that error immediately.
//	func(e ExpectedType) (bool, error)
//
// Every event Expect examines is consumed from tw. Each wait for the next
// event is bounded by tw.TimeOut. The timer restarts whenever an event arrives,
// so TimeOut limits the gap between events, not the total duration of the call.
//
// Expect returns an error in any of these cases:
//   - no filters are given
//   - a filter returns an error
//   - the timeout elapses
//   - the watcher is closed while waiting
//
// It panics if any filter does not have one of the forms above.
//
// If the list of events must match exactly with no extra events,
// use [ExpectExactly].
func Expect(tw *Watcher, filters ...any) error {
	if len(filters) == 0 {
		return errors.New("no event filters were provided")
	}
	// Convert every filter up front, for two reasons: malformed filters panic
	// before any events are consumed, and the reflection work is not repeated
	// for every observed event.
	funcs := make([]func(any) (bool, error), len(filters))
	for i, f := range filters {
		funcs[i] = eventFilter(f)
	}
	eventCount := 0
	head := 0 // number of filters matched so far
	for head < len(funcs) {
		select {
		case event := <-tw.events:
			eventCount++
			ok, err := funcs[head](event)
			if err != nil {
				return err
			}
			if ok {
				head++
			}
		case <-time.After(tw.TimeOut):
			return fmt.Errorf(
				"timed out waiting for event, saw %d events, %d was expected",
				eventCount, head)
		case <-tw.chDone:
			return errors.New("watcher closed while waiting for events")
		}
	}
	return nil
}

// ExpectExactly checks that the next len(filters) events observed by tw match
// filters one-to-one and in order. It returns an error at the first event that
// does not match. The filters have the forms described in [Expect], with two
// differences:
//   - each event's type must be exactly the filter's argument type
//   - a filter returning false is an error, not a reason to keep looking
//
// Events arriving after the last filter has matched are not examined. As in
// [Expect], each wait for the next event is bounded by tw.TimeOut. ExpectExactly
// panics if any filter does not have one of the supported forms. Use [Expect]
// if other events are allowed.
func ExpectExactly(tw *Watcher, filters ...any) error {
	if len(filters) == 0 {
		return errors.New("no event filters were provided")
	}
	// Validate and convert every filter before consuming any events.
	funcs := make([]func(any) (bool, error), len(filters))
	for i, f := range filters {
		funcs[i] = eventFilter(f)
	}
	eventCount := 0
	for pos, eventFunc := range funcs {
		// eventFilter has already checked that filters[pos] is a function
		// taking exactly one argument, so In(0) is safe here.
		argType := reflect.TypeOf(filters[pos]).In(0)
		select {
		case event := <-tw.events:
			eventCount++
			// Check the type explicitly so that a mismatch yields a more
			// specific error than the filter's generic "not ok" result.
			if typeEvent := reflect.TypeOf(event); typeEvent != argType {
				return fmt.Errorf(
					"expected event type %s, saw %s, at index %d",
					argType, typeEvent, pos)
			}
			ok, err := eventFunc(event)
			if err != nil {
				return err
			}
			if !ok {
				return fmt.Errorf(
					"expected test ok for type %s, at index %d", argType, pos)
			}
		case <-time.After(tw.TimeOut):
			return fmt.Errorf(
				"timed out waiting for event, saw %d events, %d was expected",
				eventCount, pos)
		case <-tw.chDone:
			return errors.New("watcher closed while waiting for events")
		}
	}
	return nil
}

// watch copies event payloads from the bus subscription into tw.events until
// tw.chDone is closed, and then closes the subscription.
//
// The send into tw.events also selects on tw.chDone. Without that, a full
// buffer (nobody calling Expect*) would block the goroutine forever and leak
// it together with the open subscription. If the subscription's channel is
// ever closed, watch returns rather than forwarding zero values.
func (tw *Watcher) watch() {
	for {
		select {
		case <-tw.chDone:
			tw.mon.Close()
			return
		case ev, ok := <-tw.mon.Events():
			if !ok {
				return
			}
			select {
			case tw.events <- ev.Event:
			case <-tw.chDone:
				tw.mon.Close()
				return
			}
		}
	}
}

// done shuts the watcher down by closing tw.chDone. This causes two things:
//   - the watch goroutine stops monitoring the bus
//   - any Expect* call blocked waiting on the watcher returns an error
//
// Closing a channel twice panics, so done must be called at most once;
// [NewWatcher] registers it as a test cleanup.
func (tw *Watcher) done() {
	ch := tw.chDone
	close(ch)
}

type filter = func(any) (bool, error)

func eventFilter(f any) filter {
	ft := reflect.TypeOf(f)
	if ft.Kind() != reflect.Func {
		panic("filter is not a function")
	} else if ft.NumIn() != 1 {
		panic(fmt.Sprintf("function takes %d arguments, want 1", ft.NumIn()))
	}
	var fixup func([]reflect.Value) []reflect.Value
	switch ft.NumOut() {
	case 0:
		fixup = func([]reflect.Value) []reflect.Value {
			return []reflect.Value{reflect.ValueOf(true), reflect.Zero(reflect.TypeFor[error]())}
		}
	case 1:
		if ft.Out(0) != reflect.TypeFor[bool]() {
			panic(fmt.Sprintf("result is %T, want bool", ft.Out(0)))
		}
		fixup = func(vals []reflect.Value) []reflect.Value {
			return append(vals, reflect.Zero(reflect.TypeFor[error]()))
		}
	case 2:
		if ft.Out(0) != reflect.TypeFor[bool]() || ft.Out(1) != reflect.TypeFor[error]() {
			panic(fmt.Sprintf("results are %T, %T; want bool, error", ft.Out(0), ft.Out(1)))
		}
		fixup = func(vals []reflect.Value) []reflect.Value { return vals }
	default:
		panic(fmt.Sprintf("function returns %d values", ft.NumOut()))
	}
	fv := reflect.ValueOf(f)
	return reflect.MakeFunc(reflect.TypeFor[filter](), func(args []reflect.Value) []reflect.Value {
		if !args[0].IsValid() || args[0].Elem().Type() != ft.In(0) {
			return []reflect.Value{reflect.ValueOf(false), reflect.Zero(reflect.TypeFor[error]())}
		}
		return fixup(fv.Call([]reflect.Value{args[0].Elem()}))
	}).Interface().(filter)
}
